# Copyright 2020 The TensorStore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""setuptools script for installing the Python package.

This invokes bazel via the included `bazelisk.py` wrapper script.
"""

import sys

if sys.version_info < (3, 10):
  print('Python >= 3.10 is required to build')
  sys.exit(1)

# Import setuptools before distutils because setuptools monkey patches
# distutils:
#
# https://github.com/pypa/setuptools/commit/bd1102648109c85c782286787e4d5290ae280abe
import setuptools

import atexit
import distutils.cmd
import distutils.command.build
import os
import pathlib
import shlex
import shutil
import subprocess
import sysconfig
import tempfile

import setuptools.command.build_ext
import setuptools.command.build_py
import setuptools.command.install
import setuptools.command.sdist
import setuptools.dist

# Directory containing this setup script (may be '' when run from it).
_REPO_ROOT = os.path.split(__file__)[0]


def _setup_temp_egg_info(cmd: distutils.cmd.Command) -> None:
  """Redirect the `egg_info` output of `cmd` into a throwaway directory.

  When building an sdist (source distribution) or installing, the `.egg-info`
  directory is placed inside a temporary directory created under the current
  directory. This keeps it out of the source tree and avoids picking up a stale
  SOURCES.txt from an earlier build. A location that was already configured
  (non-None `egg_base`) is left untouched.

  Args:
    cmd: The setuptools command whose distribution's `egg_info` command
      should be redirected.
  """
  egg_info = cmd.distribution.get_command_obj('egg_info')
  if egg_info.egg_base is not None:
    # An explicit location was already chosen; respect it.
    return
  scratch = tempfile.TemporaryDirectory(dir=os.curdir)
  # The registered bound method keeps `scratch` alive until interpreter exit,
  # at which point the directory is removed.
  atexit.register(scratch.cleanup)
  egg_info.egg_base = scratch.name


def _parse_version(version_str: str):
  """Split a dotted version string into a tuple of ints for comparison.

  For example, '10.14' becomes (10, 14). Raises ValueError if any
  dot-separated component is not an integer.
  """
  components = version_str.split('.')
  return tuple(map(int, components))


class SdistCommand(setuptools.command.sdist.sdist):
  """sdist command that keeps `.egg-info` out of the tree and the archive."""

  def run(self):
    # Redirect egg_info output into a temporary directory before building.
    _setup_temp_egg_info(self)
    super().run()

  def make_release_tree(self, base_dir, files):
    # `.egg-info` metadata isn't needed in the source distribution. Because
    # `run` places it in a temporary directory, its paths would also be wrong
    # if it were included.
    kept_files = list(filter(lambda path: '.egg-info' not in path, files))
    super().make_release_tree(base_dir, kept_files)


class BuildCommand(distutils.command.build.build):
  """build command that builds in a temporary directory by default."""

  def finalize_options(self):
    # Only a `build_base` of 'build' is redirected. That value would create a
    # `build` sub-directory in the source tree; any other location is kept.
    using_default_base = self.build_base == 'build'
    if using_default_base:
      scratch = tempfile.TemporaryDirectory()
      # The registered bound method keeps `scratch` alive until exit, then
      # removes it.
      atexit.register(scratch.cleanup)
      self.build_base = scratch.name
    super().finalize_options()


# Fully-qualified module names that are never shipped in the built package.
_EXCLUDED_PYTHON_MODULES = frozenset({
    'tensorstore.bazel_pytest_main',
    'tensorstore.shell',
    'tensorstore.cc_test_driver_main',
    'tensorstore.generate_type_stubs',
})


def _include_python_module(name):
  """Returns whether the dotted module `name` belongs in the package.

  Modules listed in `_EXCLUDED_PYTHON_MODULES`, and any module whose name ends
  in `_test`, are left out. Everything else is included.
  """
  is_excluded = name in _EXCLUDED_PYTHON_MODULES
  is_test = name.endswith('_test')
  return not (is_excluded or is_test)


class BuildPyCommand(setuptools.command.build_py.build_py):
  """build_py command that omits test and tooling modules."""

  def find_package_modules(self, package, package_dir):
    # Each entry is a (package, module, path) triple. Keep only those whose
    # dotted module name passes `_include_python_module`.
    kept = []
    for pkg, mod, path in super().find_package_modules(package, package_dir):
      if _include_python_module(f'{pkg}.{mod}'):
        kept.append((pkg, mod, path))
    return kept


def _configure_macos_deployment_target() -> str:
  """Configures the MACOSX_DEPLOYMENT_TARGET environment variable.

  The target starts at the TensorStore minimum (10.14). It is raised to the
  target that `sysconfig` reports for the running Python build when that is at
  least as high. A `MACOSX_DEPLOYMENT_TARGET` environment variable is honored
  if it is at least the resulting value; otherwise a warning is printed and
  the variable is overridden.

  Returns:
    The configured value, which is also stored in `os.environ`.
  """

  # TensorStore requires MACOSX_DEPLOYMENT_TARGET >= 10.14 in
  # order to support sized/aligned operator new/delete.
  min_macos_target = '10.14'
  key = 'MACOSX_DEPLOYMENT_TARGET'

  macos_target = min_macos_target

  # The specific python version may have a different target.
  # Use that if it is higher than the tensorstore minimum.
  #
  # `get_config_var` returns None when the variable is undefined, and may
  # return a non-string value, so convert to `str` only when it is present.
  # Calling `str()` unconditionally would turn None into the truthy string
  # 'None'. That would defeat the check below and make `_parse_version` raise.
  sysconfig_value = sysconfig.get_config_var(key)
  sysconfig_macos_target = (
      '' if sysconfig_value is None else str(sysconfig_value)
  )
  if sysconfig_macos_target and _parse_version(
      sysconfig_macos_target
  ) >= _parse_version(macos_target):
    macos_target = sysconfig_macos_target

  # The environment variable may be set by the user.
  # Use that if it is at least the required target, otherwise warn.
  env_macos_target = os.getenv(key)
  if env_macos_target:
    if _parse_version(env_macos_target) >= _parse_version(macos_target):
      macos_target = env_macos_target
    else:
      warning = (
          f'WARNING: {key}={env_macos_target} is set in environment but'
          f' >={min_macos_target} is required by this package'
      )
      # Only mention the Python build's requirement when it is known.
      if sysconfig_macos_target:
        warning += (
            f' and >={sysconfig_macos_target} is required by the current'
            ' Python build'
        )
      print(f'{warning}; using {key}={macos_target} instead')

  # Set MACOSX_DEPLOYMENT_TARGET in the environment, because the `wheel` package
  # checks there.  Note that Bazel receives the version via a command-line
  # option instead.
  os.environ[key] = macos_target
  return macos_target


def _remove_virtual_env_paths(parts: list[str], venv_dir: str) -> list[str]:
  """Removes from `parts` any paths located inside `venv_dir`.

  Both `venv_dir` and each entry are normalized with `os.path.normpath`
  before comparison. Containment is checked per path component, so
  `/a/venv2/bin` is not considered to be inside `/a/venv`. `venv_dir` itself
  is also removed.

  Args:
    parts: The list of path entries to filter. It is not modified.
    venv_dir: The directory whose entries should be removed.

  Returns:
    A new list with the remaining entries, in their original order.
  """
  venv_dir = os.path.normpath(venv_dir)
  return [
      part
      for part in parts
      if not pathlib.Path(os.path.normpath(part)).is_relative_to(venv_dir)
  ]


def _normalize_path(path: str) -> str:
  """Normalize the PATH used when invoking Bazel.

  `pip install .` creates a temporary build environment, and adjusts the PATH
  to add a binary path and an overlay path. As bazel is already hermetic, the
  goal is to remove the build environment paths; successful removal will allow
  bazel to improve cache sharing.

  Unfortunately there is no kosher way to detect the PEP517 build environment,
  so this is a heuristic approach based on inspection of the PATH variable
  passed to the build and review of the PIP sources. The source, as reviewed,
  creates a temporary directory including pip-{kind}, where kind=build-env.

  When using `uv build`, the `VIRTUAL_ENV` environment variable is set
  to the temporary build environment, which can be detected directly.

  See: dist-packages/pip/_internal/utils/temp_dir.py
  Also: https://github.com/bazelbuild/bazel/issues/18809

  Args:
    path: The PATH to normalize.

  Returns:
    The normalized PATH.
  """

  parts = path.split(os.pathsep)

  # Remove any VIRTUAL_ENV paths if present.
  #
  # uv build sets `VIRTUAL_ENV` for its temporary build environment.  An
  # empty value is treated as unset: `os.path.normpath('')` is '.', which
  # would otherwise cause every relative PATH entry to be dropped.
  if virtual_env := os.getenv('VIRTUAL_ENV'):
    parts = _remove_virtual_env_paths(parts, virtual_env)

  # Remove pip temporary build environment, if present.
  #
  # pip does not set VIRTUAL_ENV; instead it must be detected
  # heuristically, starting from the first entry mentioning `pip-build-env`.
  build_env = next((x for x in parts if 'pip-build-env' in x), None)
  if build_env is not None:
    # There may be multiple path entries added under the build-env directory,
    # so walk up to the outermost ancestor whose path still mentions
    # `pip-build-env` and remove everything beneath it.  Stop if `dirname`
    # reaches a fixed point (e.g. a filesystem root or UNC share) so the loop
    # always terminates.
    while True:
      parent = os.path.dirname(build_env)
      if parent == build_env or 'pip-build-env' not in parent:
        break
      build_env = parent

    parts = _remove_virtual_env_paths(parts, build_env)

  return os.pathsep.join(parts)


if 'darwin' in sys.platform:
  # Computed once at import time; this also exports MACOSX_DEPLOYMENT_TARGET
  # into os.environ. BuildExtCommand later passes the value to Bazel as
  # `--macos_minimum_os`. Defined only on macOS and read only on macOS, under
  # the same `'darwin' in sys.platform` condition.
  _macos_deployment_target = _configure_macos_deployment_target()


class BuildExtCommand(setuptools.command.build_ext.build_ext):
  """Overrides default build_ext command to invoke bazel.

  Unless `TENSORSTORE_PREBUILT_DIR` is set, the extension module is built by
  running Bazel through `bazelisk.py` (or the script named by
  `TENSORSTORE_BAZELISK`). The resulting shared library is then copied to the
  path setuptools expects for the extension.
  """

  def run(self):
    ext = self.extensions[0]
    ext_full_path = self.get_ext_fullpath(ext.name)

    prebuilt_path = os.getenv('TENSORSTORE_PREBUILT_DIR')
    if not prebuilt_path:

      env = os.environ.copy()

      # Bazel cache includes PATH; attempt to remove the pip build-env
      # from the PATH as bazel is already hermetic to improve cache use.
      if 'PATH' in env:
        env['PATH'] = _normalize_path(env['PATH'])

      # Ensure python_configure.bzl finds the correct Python version.
      env['PYTHON_BIN_PATH'] = sys.executable

      bazelisk = os.getenv('TENSORSTORE_BAZELISK', 'bazelisk.py')
      # Controlled via `setup.py build_ext --debug` flag.
      default_compilation_mode = 'dbg' if self.debug else 'opt'
      startup_options = shlex.split(
          os.getenv('TENSORSTORE_BAZEL_STARTUP_OPTIONS', '')
      )
      build_options = shlex.split(
          os.getenv('TENSORSTORE_BAZEL_BUILD_OPTIONS', '')
      )

      # Build with a specific compilation mode.
      # When in opt mode also set -O3 optimizations to override bazel -O2.
      compilation_mode = os.getenv(
          'TENSORSTORE_BAZEL_COMPILATION_MODE', default_compilation_mode
      )
      build_flags = ['build', '-c', compilation_mode]
      if compilation_mode == 'opt':
        if sys.platform == 'win32':
          # Assumes MSVC compiler.
          build_flags.append('--copt=/Ox')
        else:
          build_flags.append('--copt=-O3')

      version_stamp = self.distribution.get_version()
      if version_stamp:
        build_flags.append(
            '--//tensorstore/internal/version=%s' % version_stamp
        )

      if sysconfig.get_config_var('Py_GIL_DISABLED') == '1':
        build_flags.append(
            '--@rules_python//python/config_settings:py_freethreaded=yes'
        )

      build_command = (
          [sys.executable, '-u', bazelisk]
          + startup_options
          + build_flags
          + [
              '//python/tensorstore:_tensorstore__shared_objects',
              '--verbose_failures',
              # Bazel does not seem to download these files by default when
              # using remote caching.
              r'--remote_download_regex=.*/(_tensorstore\.(so|pyd))',
          ]
          + build_options
      )
      if 'darwin' in sys.platform:
        # Note: Bazel does not use the MACOSX_DEPLOYMENT_TARGET environment
        # variable.
        build_command += ['--macos_minimum_os=%s' % _macos_deployment_target]
        # Support cross-compilation on macOS
        # https://github.com/pypa/cibuildwheel/discussions/997#discussioncomment-2045760
        darwin_cpus = [
            x for x in os.getenv('ARCHFLAGS', '').split() if x != '-arch'
        ]
        # cibuildwheel sets `ARCHFLAGS` to one of:
        #     '-arch x86_64'
        #     '-arch arm64'
        #     '-arch arm64 -arch x86_64'
        if darwin_cpus:
          if len(darwin_cpus) > 1:
            raise ValueError(
                'Fat/universal %r build not supported' % (darwin_cpus,)
            )
          darwin_cpu = darwin_cpus[0]
          build_command += [
              f'--cpu=darwin_{darwin_cpu}',
              f'--macos_cpus={darwin_cpu}',
          ]
      if sys.platform == 'win32':
        # Disable the constexpr mutex constructor which is incompatible with
        # older versions of msvcp140.dll.
        #
        # https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
        build_command += ['--copt=/D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR']

        # Disable newer exception handling from Visual Studio 2019, since it
        # requires a newer C++ runtime than shipped with Python.
        #
        # https://cibuildwheel.readthedocs.io/en/stable/faq/#importerror-dll-load-failed-the-specific-module-could-not-be-found-error-on-windows
        build_command += ['--copt=/d2FH4-']
      else:
        # Build with hidden visibility for more efficient code generation.
        # Note that this also hides most symbols, but ultimately has no effect
        # on symbol visibility because a separate linker option is already
        # used to hide all extraneous symbols anyway.
        build_command += ['--copt=-fvisibility=hidden']

      # Log only the environment variables this command sets or rewrites.
      # Printing the whole environment would copy any secrets it contains
      # (e.g. CI access tokens) into the build log.
      print(
          f'Executing {shlex.join(build_command)} with'
          f' PATH={env.get("PATH")!r}'
          f' PYTHON_BIN_PATH={env["PYTHON_BIN_PATH"]!r}'
      )

      if not self.dry_run:
        subprocess.check_call(build_command, env=env)

      # Bazel places outputs under `bazel-bin` relative to the directory it
      # was invoked from (the current directory, as above).
      suffix = '.pyd' if os.name == 'nt' else '.so'
      built_ext_path = os.path.join(
          'bazel-bin', 'python', 'tensorstore', '_tensorstore' + suffix
      )
    else:
      # If `TENSORSTORE_PREBUILT_DIR` is set, the extension module is assumed
      # to have already been built a prior call to `build_ext -b
      # $TENSORSTORE_PREBUILT_DIR`.
      #
      # This is used in conjunction with cibuildwheel to first perform an
      # in-tree build of the extension module in order to take advantage of
      # Bazel caching:
      #
      # https://github.com/pypa/pip/pull/9091
      # https://github.com/joerick/cibuildwheel/issues/486
      built_ext_path = os.path.join(
          prebuilt_path, 'tensorstore', os.path.basename(ext_full_path)
      )

    if self.dry_run:
      print('Dry run, skipping build')
      return

    os.makedirs(os.path.dirname(ext_full_path), exist_ok=True)
    print(
        'Copying extension %s -> %s'
        % (
            built_ext_path,
            ext_full_path,
        )
    )
    shutil.copyfile(built_ext_path, ext_full_path)


class InstallCommand(setuptools.command.install.install):
  """install command that keeps `.egg-info` out of the source directory."""

  def run(self):
    # Point egg_info at a temporary directory, then install as usual.
    _setup_temp_egg_info(self)
    super().run()


setuptools.setup(
    # Python packages live under the `python/` sub-directory.
    packages=setuptools.find_packages('python'),
    package_dir={'': 'python'},
    # Keep C++ headers/sources and Bazel BUILD files out of the package data.
    # `*.py` is excluded here too; Python modules are instead collected (and
    # filtered) by `BuildPyCommand.find_package_modules`.
    exclude_package_data={'': ['*.h', '*.cc', '*.py', 'BUILD*']},
    # Placeholder extension with no sources: `BuildExtCommand` builds it with
    # Bazel (or copies a prebuilt copy) rather than compiling it directly.
    ext_modules=[setuptools.Extension('tensorstore/_tensorstore', sources=[])],
    setup_requires=['setuptools_scm>=8.1.0'],
    # Custom commands defined above: redirect temporary build/egg-info output
    # away from the source tree and delegate the native build to Bazel.
    cmdclass={
        'sdist': SdistCommand,
        'build': BuildCommand,
        'build_py': BuildPyCommand,
        'build_ext': BuildExtCommand,
        'install': InstallCommand,
    },
)
